{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1 align=\"center\">SimpleITK Spatial Transformations</h1>\n",
    "\n",
    "\n",
    "**Summary:**\n",
    "\n",
    "1. Points are represented by vector-like data types: Tuple, Numpy array, List.\n",
    "2. Matrices are represented by vector-like data types in row major order.\n",
    "3. Default transformation initialization as the identity transform.\n",
    "4. Angles specified in radians, distances specified in unknown but consistent units (nm,mm,m,km...).\n",
    "5. All global transformations **except translation** are of the form:\n",
    "$$T(\\mathbf{x}) = A(\\mathbf{x}-\\mathbf{c}) + \\mathbf{t} + \\mathbf{c}$$\n",
    "\n",
    "   Nomenclature (when printing your transformation):\n",
    "\n",
    "   * Matrix: the matrix $A$\n",
    "   * Center: the point $\\mathbf{c}$\n",
    "   * Translation: the vector $\\mathbf{t}$\n",
    "   * Offset: $\\mathbf{t} + \\mathbf{c} - A\\mathbf{c}$\n",
    "6. Bounded transformations, BSplineTransform and DisplacementFieldTransform, behave as the identity transform outside the defined bounds.\n",
    "7. DisplacementFieldTransform:\n",
    "   * Initializing the DisplacementFieldTransform using an image requires that the image's pixel type be sitk.sitkVectorFloat64.\n",
    "   * Initializing the DisplacementFieldTransform using an image will \"clear out\" your image (your alias to the image will point to an empty, zero sized, image).\n",
    "8. Composite transformations are applied in stack order (first added, last applied)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transformation Types\n",
    "\n",
    "SimpleITK supports the following transformation types.\n",
    "\n",
    "| Class Name | Details|\n",
    "|:-------------|:---------|\n",
    "|[TranslationTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1TranslationTransform.html) | 2D or 3D, translation|\n",
    "|[VersorTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1VersorTransform.html)| 3D, rotation represented by a versor|\n",
    "|[VersorRigid3DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1VersorRigid3DTransform.html)|3D, rigid transformation with rotation represented by a versor|\n",
    "|[Euler2DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1Euler2DTransform.html)| 2D, rigid transformation with rotation represented by a Euler angle|\n",
    "|[Euler3DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1Euler3DTransform.html)| 3D, rigid transformation with rotation represented by Euler angles|\n",
    "|[Similarity2DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1Similarity2DTransform.html)| 2D, composition of isotropic scaling and rigid transformation with rotation represented by a Euler angle|\n",
    "|[Similarity3DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1Similarity3DTransform.html) | 3D, composition of isotropic scaling and rigid transformation with rotation represented by a versor|\n",
    "|[ScaleTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1ScaleTransform.html)|2D or 3D, anisotropic scaling|\n",
    "|[ScaleVersor3DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1ScaleVersor3DTransform.html)| 3D, rigid transformation and anisotropic scale is **added** to the rotation matrix part (not composed as one would expect)|\n",
    "|[ScaleSkewVersor3DTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1ScaleSkewVersor3DTransform.html#details)|3D, rigid transformation with anisotropic scale and skew matrices **added** to the rotation matrix part (not composed as one would expect) |\n",
    "|[AffineTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1AffineTransform.html)| 2D or 3D, affine transformation|\n",
    "|[BSplineTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1BSplineTransform.html)|2D or 3D, deformable transformation represented by a sparse regular grid of control points |\n",
    "|[DisplacementFieldTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1DisplacementFieldTransform.html)| 2D or 3D, deformable transformation represented as a dense regular grid of vectors|\n",
    "|[CompositeTransform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1CompositeTransform.html)| 2D or 3D, stack of transformations concatenated via composition, last added, first applied|\n",
    "|[Transform](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1Transform.html#details) | 2D or 3D, parent/super-class for all transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import SimpleITK as sitk\n",
    "import utilities as util\n",
    "\n",
    "import numpy as np\n",
    "%matplotlib inline \n",
    "import matplotlib.pyplot as plt\n",
    "from ipywidgets import interact, fixed\n",
    "\n",
    "OUTPUT_DIR = \"output\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will introduce the transformation types, starting with translation and illustrating how to move from a lower to higher parameter space (e.g. translation to rigid).  \n",
    "\n",
    "We start with the global transformations. All of them <b>except translation</b> are of the form:\n",
    "$$T(\\mathbf{x}) = A(\\mathbf{x}-\\mathbf{c}) + \\mathbf{t} + \\mathbf{c}$$\n",
    "\n",
    "In ITK speak (when printing your transformation):\n",
    "<ul>\n",
    "<li>Matrix: the matrix $A$</li>\n",
    "<li>Center: the point $\\mathbf{c}$</li>\n",
    "<li>Translation: the vector $\\mathbf{t}$</li>\n",
    "<li>Offset: $\\mathbf{t} + \\mathbf{c} - A\\mathbf{c}$</li>\n",
    "</ul>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TranslationTransform\n",
    "\n",
    "Create a translation and then transform a point and use the inverse transformation to get the original back."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dimension = 2        \n",
    "offset = [2]*dimension # use a Python trick to create the offset list based on the dimension\n",
    "translation = sitk.TranslationTransform(dimension, offset)\n",
    "print(translation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "point = [10, 11] if dimension==2 else [10, 11, 12] # set point to match dimension\n",
    "transformed_point = translation.TransformPoint(point)\n",
    "translation_inverse = translation.GetInverse()\n",
    "print('original point: ' + util.point2str(point) + '\\n'\n",
    "      'transformed point: ' + util.point2str(transformed_point) + '\\n'\n",
    "      'back to original: ' + util.point2str(translation_inverse.TransformPoint(transformed_point)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Euler2DTransform\n",
    "\n",
    "Rigidly transform a 2D point using a Euler angle parameter specification.\n",
    "\n",
    "Notice that the dimensionality of the Euler angle based rigid transformation is associated with the class, unlike the translation which is set at construction.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "point = [10, 11]\n",
    "rotation2D = sitk.Euler2DTransform()\n",
    "rotation2D.SetTranslation((7.2, 8.4))\n",
    "rotation2D.SetAngle(np.pi/2)\n",
    "print('original point: ' + util.point2str(point) + '\\n'\n",
    "      'transformed point: ' + util.point2str(rotation2D.TransformPoint(point)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VersorTransform (rotation in 3D)\n",
    "\n",
    "Rotation using a versor, vector part of unit quaternion, parametrization. Quaternion defined by rotation of $\\theta$ radians around axis $n$, is $q = [n*\\sin(\\frac{\\theta}{2}), \\cos(\\frac{\\theta}{2})]$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use a versor:\n",
    "rotation1 = sitk.VersorTransform([0,0,1,0])\n",
    "\n",
    "# Use axis-angle:\n",
    "rotation2 = sitk.VersorTransform((0,0,1), np.pi)\n",
    "\n",
    "# Use a matrix:\n",
    "rotation3 = sitk.VersorTransform()\n",
    "rotation3.SetMatrix([-1, 0, 0, 0, -1, 0, 0, 0, 1]);\n",
    "\n",
    "point = (10, 100, 1000)\n",
    "\n",
    "p1 = rotation1.TransformPoint(point)\n",
    "p2 = rotation2.TransformPoint(point)\n",
    "p3 = rotation3.TransformPoint(point)\n",
    "\n",
    "print('Points after transformation:\\np1=' + str(p1) + \n",
    "      '\\np2='+ str(p2) + '\\np3='+ str(p3))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Translation to Rigid [3D]\n",
    "\n",
    "We only need to copy the translational component."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dimension = 3        \n",
    "t =(1,2,3) \n",
    "translation = sitk.TranslationTransform(dimension, t)\n",
    "\n",
    "# Copy the translational component.\n",
    "rigid_euler = sitk.Euler3DTransform()\n",
    "rigid_euler.SetTranslation(translation.GetOffset())\n",
    "\n",
    "# Apply the transformations to the same set of random points and compare the results.\n",
    "util.print_transformation_differences(translation, rigid_euler)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Rotation to Rigid [3D]\n",
    "Copy the matrix or versor and <b>center of rotation</b>."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rotation_center = (10, 10, 10)\n",
    "rotation = sitk.VersorTransform([0,0,1,0], rotation_center)\n",
    "\n",
    "rigid_versor = sitk.VersorRigid3DTransform()\n",
    "rigid_versor.SetRotation(rotation.GetVersor())\n",
    "#rigid_versor.SetCenter(rotation.GetCenter()) #intentional error, not copying center of rotation\n",
    "\n",
    "# Apply the transformations to the same set of random points and compare the results.\n",
    "util.print_transformation_differences(rotation, rigid_versor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the cell above, when we don't copy the center of rotation we have a constant error vector, $\\mathbf{c}$ - A$\\mathbf{c}$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Similarity [2D]\n",
    "\n",
    "When the center of the similarity transformation is not at the origin the effect of the transformation is not what most of us expect. This is readily visible if we limit the transformation to scaling: $T(\\mathbf{x}) = s\\mathbf{x}-s\\mathbf{c} + \\mathbf{c}$. Changing the transformation's center results in scale + translation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def display_center_effect(x, y, tx, point_list, xlim, ylim):\n",
    "    tx.SetCenter((x,y))\n",
    "    transformed_point_list = [ tx.TransformPoint(p) for p in point_list]\n",
    "\n",
    "    plt.scatter(list(np.array(transformed_point_list).T)[0],\n",
    "                list(np.array(transformed_point_list).T)[1],\n",
    "                marker='^', \n",
    "                color='red', label='transformed points')\n",
    "    plt.scatter(list(np.array(point_list).T)[0],\n",
    "                list(np.array(point_list).T)[1],\n",
    "                marker='o', \n",
    "                color='blue', label='original points')\n",
    "    plt.xlim(xlim)\n",
    "    plt.ylim(ylim)\n",
    "    plt.legend(loc=(0.25,1.01))\n",
    "\n",
    "# 2D square centered on (0,0)\n",
    "points = [np.array((-1.0,-1.0)), np.array((-1.0,1.0)), np.array((1.0,1.0)), np.array((1.0,-1.0))]\n",
    "\n",
    "# Scale by 2 \n",
    "similarity = sitk.Similarity2DTransform();\n",
    "similarity.SetScale(2)\n",
    "\n",
    "interact(display_center_effect, x=(-10,10), y=(-10,10),tx = fixed(similarity), point_list = fixed(points), \n",
    "         xlim = fixed((-10,10)),ylim = fixed((-10,10)));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Rigid to Similarity [3D]\n",
    "Copy the translation, center, and matrix or versor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rotation_center = (100, 100, 100)\n",
    "theta_x = 0.0\n",
    "theta_y = 0.0\n",
    "theta_z = np.pi/2.0\n",
    "translation = (1,2,3)\n",
    "\n",
    "rigid_euler = sitk.Euler3DTransform(rotation_center, theta_x, theta_y, theta_z, translation)\n",
    "\n",
    "similarity = sitk.Similarity3DTransform()\n",
    "similarity.SetMatrix(rigid_euler.GetMatrix())\n",
    "similarity.SetTranslation(rigid_euler.GetTranslation())\n",
    "similarity.SetCenter(rigid_euler.GetCenter())\n",
    "\n",
    "# Apply the transformations to the same set of random points and compare the results.\n",
    "util.print_transformation_differences(rigid_euler, similarity)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Similarity to Affine [3D]\n",
    "Copy the translation, center and matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rotation_center = (100, 100, 100)\n",
    "axis = (0,0,1)\n",
    "angle = np.pi/2.0\n",
    "translation = (1,2,3)\n",
    "scale_factor = 2.0\n",
    "similarity = sitk.Similarity3DTransform(scale_factor, axis, angle, translation, rotation_center)\n",
    "\n",
    "affine = sitk.AffineTransform(3)\n",
    "affine.SetMatrix(similarity.GetMatrix())\n",
    "affine.SetTranslation(similarity.GetTranslation())\n",
    "affine.SetCenter(similarity.GetCenter())\n",
    "\n",
    "# Apply the transformations to the same set of random points and compare the results.\n",
    "util.print_transformation_differences(similarity, affine)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scale Transform\n",
    "\n",
    "Just as the case was for the similarity transformation above, when the transformations center is not at the origin, instead of a pure anisotropic scaling we also have translation ($T(\\mathbf{x}) = \\mathbf{s}^T\\mathbf{x}-\\mathbf{s}^T\\mathbf{c} + \\mathbf{c}$)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2D square centered on (0,0).\n",
    "points = [np.array((-1.0,-1.0)), np.array((-1.0,1.0)), np.array((1.0,1.0)), np.array((1.0,-1.0))]\n",
    "\n",
    "# Scale by half in x and 2 in y.\n",
    "scale = sitk.ScaleTransform(2, (0.5,2));\n",
    "\n",
    "# Interactively change the location of the center.\n",
    "interact(display_center_effect, x=(-10,10), y=(-10,10),tx = fixed(scale), point_list = fixed(points), \n",
    "         xlim = fixed((-10,10)),ylim = fixed((-10,10)));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unintentional Misnomers (originally from ITK)\n",
    "\n",
    "Two transformation types whose names may mislead you are ScaleVersor and ScaleSkewVersor. Basing your choices on expectations without reading the documentation will surprise you.\n",
    "\n",
    "ScaleVersor -  based on name expected a composition of transformations, in practice it is:\n",
    "$$T(x) = (R+S)(\\mathbf{x}-\\mathbf{c}) + \\mathbf{t} + \\mathbf{c},\\;\\; \\textrm{where } S= \\left[\\begin{array}{ccc} s_0-1 & 0 & 0 \\\\ 0 & s_1-1 & 0 \\\\ 0 & 0 & s_2-1 \\end{array}\\right]$$ \n",
    "\n",
    "ScaleSkewVersor - based on name expected a composition of transformations, in practice it is:\n",
    "$$T(x) = (R+S+K)(\\mathbf{x}-\\mathbf{c}) + \\mathbf{t} + \\mathbf{c},\\;\\; \\textrm{where } S = \\left[\\begin{array}{ccc} s_0-1 & 0 & 0 \\\\ 0 & s_1-1 & 0 \\\\ 0 & 0 & s_2-1 \\end{array}\\right]\\;\\; \\textrm{and } K = \\left[\\begin{array}{ccc} 0 & k_0 & k_1 \\\\ k_2 & 0 & k_3 \\\\ k_4 & k_5 & 0 \\end{array}\\right]$$ \n",
    "\n",
    "Note that ScaleSkewVersor is is an over-parametrized version of the affine transform, 15 parameters (scale, skew, versor, translation) vs. 12 parameters (matrix, translation)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bounded Transformations\n",
    "\n",
    "SimpleITK supports two types of bounded non-rigid transformations, BSplineTransform (sparse representation) and \tDisplacementFieldTransform (dense representation).\n",
    "\n",
    "Transforming a point that is outside the bounds will return the original point - identity transform."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BSpline\n",
    "Using a sparse set of control points to control a free form deformation. Note that the order of parameters to the transformation is $[x_0\\ldots x_N,y_0\\ldots y_N, z_0\\ldots z_N]$ for $N$ control points.\n",
    "\n",
    "\n",
    "To configure this transformation type we need to specify its bounded domain and the parameters for the control points, the incremental shifts from original grid positions. This can either be done explicitly by specifying the set of parameters defining the domain and control point parameters one by one or by using a set of images that encode all of this information in a more compact manner.\n",
    "\n",
    "The next two code cells illustrate these two options."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the transformation (when working with images it is easier to use the BSplineTransformInitializer function\n",
    "# or its object oriented counterpart BSplineTransformInitializerFilter).\n",
    "dimension = 2\n",
    "spline_order = 3\n",
    "direction_matrix_row_major = [1.0,0.0,0.0,1.0] # identity, mesh is axis aligned\n",
    "origin = [-1.0,-1.0]  \n",
    "domain_physical_dimensions = [2,2]\n",
    "mesh_size = [4,3]\n",
    "\n",
    "bspline = sitk.BSplineTransform(dimension, spline_order)\n",
    "bspline.SetTransformDomainOrigin(origin)\n",
    "bspline.SetTransformDomainDirection(direction_matrix_row_major)\n",
    "bspline.SetTransformDomainPhysicalDimensions(domain_physical_dimensions)\n",
    "bspline.SetTransformDomainMeshSize(mesh_size)\n",
    "\n",
    "# Random displacement of the control points, specifying the x and y\n",
    "# displacements separately allows us to play with these parameters,\n",
    "# just multiply one of them with zero to see the effect.\n",
    "x_displacement = np.random.random(len(bspline.GetParameters())//2)\n",
    "y_displacement = np.random.random(len(bspline.GetParameters())//2)\n",
    "original_control_point_displacements = np.concatenate([x_displacement, y_displacement])\n",
    "bspline.SetParameters(original_control_point_displacements)\n",
    "\n",
    "# Apply the BSpline transformation to a grid of points \n",
    "# starting the point set exactly at the origin of the BSpline mesh is problematic as\n",
    "# these points are considered outside the transformation's domain,\n",
    "# remove epsilon below and see what happens.\n",
    "numSamplesX = 10\n",
    "numSamplesY = 20\n",
    "                   \n",
    "coordsX = np.linspace(origin[0]+np.finfo(float).eps, origin[0] + domain_physical_dimensions[0], numSamplesX)\n",
    "coordsY = np.linspace(origin[1]+np.finfo(float).eps, origin[1] + domain_physical_dimensions[1], numSamplesY)\n",
    "XX, YY = np.meshgrid(coordsX, coordsY)\n",
    "\n",
    "interact(util.display_displacement_scaling_effect, s= (-1.5,1.5), original_x_mat = fixed(XX), original_y_mat = fixed(YY),\n",
    "         tx = fixed(bspline), original_control_point_displacements = fixed(original_control_point_displacements));            "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We next define the same BSpline transformation using a set of coefficient images. Note that to compare the parameter values for the two transformations we need to scale the values in the new transformation using the scale value used in the GUI above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "simpleitk_error_expected": "invalid syntax"
   },
   "outputs": [],
   "source": [
    "control_point_number = [sz+spline_order for sz in mesh_size]\n",
    "num_parameters_per_axis = np.prod(control_point_number)\n",
    "\n",
    "coefficient_images = []\n",
    "for i in range(dimension):\n",
    "    coefficient_image = sitk.GetImageFromArray((original_control_point_displacements[i*num_parameters_per_axis:(i+1)*num_parameters_per_axis]).reshape(control_point_number))\n",
    "    coefficient_image.SetOrigin(origin)\n",
    "    coefficient_image.SetSpacing([sz/(cp-1) for cp,sz in zip(control_point_number, domain_physical_dimensions)])\n",
    "    coefficient_image.SetDirection(direction_matrix_row_major)\n",
    "    coefficient_images.append(coefficient_image)\n",
    "\n",
    "bspline2 = sitk.BSplineTransform(coefficient_images, spline_order)\n",
    "\n",
    "# Show that the two transformations have the same set of values in the control point parameter space:\n",
    "# Set the scale value based on the slider value in the GUI above.\n",
    "scale_factor_from_gui = \n",
    "print(np.array(bspline.GetParameters()) - np.array(bspline2.GetParameters())*scale_factor_from_gui)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DisplacementField\n",
    "\n",
    "A dense set of vectors representing the displacement inside the given domain. The most generic representation of a transformation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the displacement field. \n",
    "    \n",
    "# When working with images the safer thing to do is use the image based constructor,\n",
    "# sitk.DisplacementFieldTransform(my_image), all the fixed parameters will be set correctly and the displacement\n",
    "# field is initialized using the vectors stored in the image. SimpleITK requires that the image's pixel type be \n",
    "# sitk.sitkVectorFloat64.\n",
    "displacement = sitk.DisplacementFieldTransform(2)\n",
    "field_size = [10,20]\n",
    "field_origin = [-1.0,-1.0]  \n",
    "field_spacing = [2.0/9.0,2.0/19.0]   \n",
    "field_direction = [1,0,0,1] # direction cosine matrix (row major order)     \n",
    "\n",
    "# Concatenate all the information into a single list\n",
    "displacement.SetFixedParameters(field_size+field_origin+field_spacing+field_direction)\n",
    "# Set the interpolator, either sitkLinear which is default or nearest neighbor\n",
    "displacement.SetInterpolator(sitk.sitkNearestNeighbor)\n",
    "\n",
    "originalDisplacements = np.random.random(len(displacement.GetParameters()))\n",
    "displacement.SetParameters(originalDisplacements)\n",
    "\n",
    "coordsX = np.linspace(field_origin[0], field_origin[0]+(field_size[0]-1)*field_spacing[0], field_size[0])\n",
    "coordsY = np.linspace(field_origin[1], field_origin[1]+(field_size[1]-1)*field_spacing[1], field_size[1])\n",
    "XX, YY = np.meshgrid(coordsX, coordsY)\n",
    "\n",
    "interact(util.display_displacement_scaling_effect, s= (-1.5,1.5), original_x_mat = fixed(XX), original_y_mat = fixed(YY),\n",
    "         tx = fixed(displacement), original_control_point_displacements = fixed(originalDisplacements));            "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inverting bounded transforms\n",
    "\n",
    "In SimpleITK we cannot directly invert a BSpline transform. Luckily there are several ways to invert a displacement field transform, and **all** transformations can be readily converted to a displacement field. Note though that representing a transformation as a deformation field is an approximation of the original transformation where representation consistency depends on the smoothness of the original transformation and the sampling rate (spacing) of the deformation field.\n",
    "\n",
    "The relevant classes are listed below.\n",
    "* [TransformToDisplacementFieldFilter](https://itk.org/SimpleITKDoxygen/html/classitk_1_1simple_1_1TransformToDisplacementFieldFilter.html)\n",
    "\n",
    "Options for inverting displacement field:\n",
    "* [InvertDisplacementFieldImageFilter](https://itk.org/SimpleITKDoxygen/html/classitk_1_1simple_1_1InvertDisplacementFieldImageFilter.html)\n",
    "* [InverseDisplacementFieldImageFilter](https://itk.org/SimpleITKDoxygen/html/classitk_1_1simple_1_1InverseDisplacementFieldImageFilter.html)\n",
    "* [IterativeInverseDisplacementFieldImageFilter](https://itk.org/SimpleITKDoxygen/html/classitk_1_1simple_1_1IterativeInverseDisplacementFieldImageFilter.html)\n",
    "\n",
    "In the next cell we invert the BSpline transform we worked with above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the BSpline transform to a displacement field\n",
    "physical_size = bspline.GetTransformDomainPhysicalDimensions()\n",
    "# The deformation field spacing affects the accuracy of the transform approximation, \n",
    "# so we set it here to 0.1mm in all directions.\n",
    "output_spacing = [0.1]*bspline.GetDimension()\n",
    "output_size = [int(phys_sz/spc + 1) for phys_sz, spc in zip(physical_size, output_spacing)]\n",
    "displacement_field_transform = sitk.DisplacementFieldTransform(\n",
    "                                    sitk.TransformToDisplacementField(bspline, \n",
    "                                                                      outputPixelType=sitk.sitkVectorFloat64, \n",
    "                                                                      size=output_size, \n",
    "                                                                      outputOrigin=bspline.GetTransformDomainOrigin(), \n",
    "                                                                      outputSpacing=output_spacing, \n",
    "                                                                      outputDirection=bspline.GetTransformDomainDirection()))\n",
    "\n",
    "# Arbitrary point to evaluate the consistency of the two representations.\n",
    "# Change the value for the \"output_spacing\" above to evaluate its effect\n",
    "# on the transformation representation consistency.\n",
    "pnt = [0.4,-0.2]\n",
    "original_transformed = np.array(bspline.TransformPoint(pnt))\n",
    "secondary_transformed = np.array(displacement_field_transform.TransformPoint(pnt))\n",
    "print('Original transformation result: {0}'.format(original_transformed))\n",
    "print('Deformaiton field transformation result: {0}'.format(secondary_transformed))\n",
    "print('Difference between transformed points is: {0}'.format(np.linalg.norm(original_transformed - secondary_transformed)))\n",
    "\n",
    "# Invert a displacement field transform\n",
    "displacement_image = displacement_field_transform.GetDisplacementField()\n",
    "bspline_inverse_displacement = sitk.DisplacementFieldTransform(\n",
    "                                   sitk.InvertDisplacementField(displacement_image,\n",
    "                                                                maximumNumberOfIterations = 20,\n",
    "                                                                maxErrorToleranceThreshold = 0.01,\n",
    "                                                                meanErrorToleranceThreshold = 0.0001,\n",
    "                                                                enforceBoundaryCondition = True))                         \n",
    "\n",
    "\n",
    "# Transform the point using the original BSpline transformation and then back\n",
    "# via the displacement field inverse.\n",
    "there_and_back = np.array(bspline_inverse_displacement.TransformPoint(bspline.TransformPoint(pnt)))\n",
    "print('Original point: {0}'.format(pnt))\n",
    "print('There and back point: {0}'.format(there_and_back))\n",
    "print('Difference between original and there-and-back points: {0}'.format(np.linalg.norm(pnt - there_and_back)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Composite transform\n",
    "\n",
    "This class represents one or more transformations applied one after the other, with stack semantics (first added last applied). \n",
    "\n",
    "The choice of whether to use a composite transformation or compose transformations on your own has subtle differences in the registration framework.\n",
    "\n",
    "Composite transforms enable a combination of a global transformation with multiple local/bounded transformations. This is useful if we want to apply deformations only in regions that deform while other regions are only effected by the global transformation.\n",
    "\n",
    "The following code illustrates this, where the whole region is translated and subregions have different deformations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global transformation.\n",
    "translation = sitk.TranslationTransform(2,(1.0,0.0))\n",
    "\n",
    "# Displacement in region 1.\n",
    "displacement1 = sitk.DisplacementFieldTransform(2)\n",
    "field_size = [10,20]\n",
    "field_origin = [-1.0,-1.0]  \n",
    "field_spacing = [2.0/9.0,2.0/19.0]   \n",
    "field_direction = [1,0,0,1] # direction cosine matrix (row major order)     \n",
    "\n",
    "# Concatenate all the information into  a single list.\n",
    "displacement1.SetFixedParameters(field_size+field_origin+field_spacing+field_direction)\n",
    "displacement1.SetParameters(np.ones(len(displacement1.GetParameters())))\n",
    "\n",
    "# Displacement in region 2.\n",
    "displacement2 = sitk.DisplacementFieldTransform(2)\n",
    "field_size = [10,20]\n",
    "field_origin = [1.0,-3]  \n",
    "field_spacing = [2.0/9.0,2.0/19.0]   \n",
    "field_direction = [1,0,0,1] #direction cosine matrix (row major order)     \n",
    "\n",
    "# Concatenate all the information into a single list.\n",
    "displacement2.SetFixedParameters(field_size+field_origin+field_spacing+field_direction)\n",
    "displacement2.SetParameters(-1.0*np.ones(len(displacement2.GetParameters())))\n",
    "\n",
    "# Composite transform which applies the global and local transformations.\n",
    "composite = sitk.CompositeTransform([translation, displacement1, displacement2])\n",
    "\n",
    "# Apply the composite transformation to points in ([-1,-3],[3,1]) and \n",
    "# display the deformation using a quiver plot.\n",
    "        \n",
    "# Generate points.\n",
    "numSamplesX = 10\n",
    "numSamplesY = 10                   \n",
    "coordsX = np.linspace(-1.0, 3.0, numSamplesX)\n",
    "coordsY = np.linspace(-3.0, 1.0, numSamplesY)\n",
    "XX, YY = np.meshgrid(coordsX, coordsY)\n",
    "\n",
    "# Transform points and compute deformation vectors.\n",
    "pointsX = np.zeros(XX.shape)\n",
    "pointsY = np.zeros(XX.shape)\n",
    "for index, value in np.ndenumerate(XX):\n",
    "    px,py = composite.TransformPoint((value, YY[index]))\n",
    "    pointsX[index]=px - value \n",
    "    pointsY[index]=py - YY[index]\n",
    "    \n",
    "plt.quiver(XX, YY, pointsX, pointsY);    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transform\n",
    "\n",
    "This class represents a generic transform and is the return type of the registration framework (when registration is not done in place). Underneath the generic facade is one of the actual transform classes. To find out which class is under the hood, we can query the transform for its [TransformEnum](https://simpleitk.org/doxygen/latest/html/namespaceitk_1_1simple.html#a527cb966ed81d0bdc65999f4d2d4d852).\n",
    "\n",
    "We can then downcast the generic transform to its actual type and obtain access to the relevant methods."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "anonymous_transform_type = sitk.Transform(sitk.TranslationTransform(2,(1.0,0.0)))\n",
    "\n",
    "try:\n",
    "    print(anonymous_transform_type.GetOffset())\n",
    "except AttributeError:\n",
    "    print('The generic transform does not have this method.')\n",
    "    \n",
    "if anonymous_transform_type.GetTransformEnum() == sitk.sitkTranslation:\n",
    "    translation = sitk.TranslationTransform(anonymous_transform_type)\n",
    "    print(translation.GetOffset())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Writing and Reading\n",
    "\n",
    "The SimpleITK.ReadTransform() function returns a SimpleITK.Transform. The file can contain any of the SimpleITK transformations or a composite (set of transformations)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Create a 2D rigid transformation, write it to disk and read it back.\n",
    "basic_transform = sitk.Euler2DTransform()\n",
    "basic_transform.SetTranslation((1.0,2.0))\n",
    "basic_transform.SetAngle(np.pi/2)\n",
    "\n",
    "full_file_name = os.path.join(OUTPUT_DIR, 'euler2D.tfm')\n",
    "\n",
    "sitk.WriteTransform(basic_transform, full_file_name)\n",
    "\n",
    "# The ReadTransform function returns an sitk.Transform no matter the type of the transform \n",
    "# found in the file (global, bounded, composite).\n",
    "read_result = sitk.ReadTransform(full_file_name)\n",
    "\n",
    "print('Different types: '+ str(type(read_result) != type(basic_transform)))\n",
    "util.print_transformation_differences(basic_transform, read_result)\n",
    "\n",
    "\n",
    "# Create a composite transform then write and read.\n",
    "displacement = sitk.DisplacementFieldTransform(2)\n",
    "field_size = [10,20]\n",
    "field_origin = [-10.0,-100.0]  \n",
    "field_spacing = [20.0/(field_size[0]-1),200.0/(field_size[1]-1)]   \n",
    "field_direction = [1,0,0,1] # direction cosine matrix (row major order)\n",
    "\n",
    "# Concatenate all the information into a single list.\n",
    "displacement.SetFixedParameters(field_size+field_origin+field_spacing+field_direction)\n",
    "displacement.SetParameters(np.random.random(len(displacement.GetParameters())))\n",
    "\n",
    "composite_transform = sitk.CompositeTransform([basic_transform,displacement])\n",
    "\n",
    "full_file_name = os.path.join(OUTPUT_DIR, 'composite.tfm')\n",
    "\n",
    "sitk.WriteTransform(composite_transform, full_file_name)\n",
    "read_result = sitk.ReadTransform(full_file_name)\n",
    "\n",
    "util.print_transformation_differences(composite_transform, read_result)    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"02_images_and_resampling.ipynb\"><h2 align=\"right\">Next &raquo;</h2></a>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
